#################################################### 
#Author: Kelli Marquardt
#Purpose: Compute the ADHD match values based on patient note text 

# Inputs:
#- data/intermediate/DSM_Vectors.csv
#- data/intermediate/pat_notes_fake_Vectors.csv
#- data/intermediate/note_dat_fake_wQ.csv

# Outputs:
#- data/intermediate/nlp_match_dat_fake.csv
#- output/tables/tab_c1.txt

####################################################


############################
#0 load required packages
############################
rm(list = ls(all.names = TRUE))

#load packages  
library(parallel)
library(parallelMap)
library(pbapply) #for apply function progress bar 
library(stringr) #for str_detect
library(data.table) #fread and fwrite
library(dplyr)
library(tidytext) #for unnest tokens
library(tm) #for less restrictive set of stop words 
library(tidyr) #for pivot wider 




#################################################
#Step1: Read in DSM vector and define a stem to stem mapping 
#################################################

# DSM reference vector; columns used below: type, stem, close_stem, synonym_stem
DSM_Vector <- fread(file.path("..", "data", "intermediate", "DSM_Vectors.csv"),
                    stringsAsFactors = FALSE)

# Stem-to-stem crosswalk, one table per DSM type, stored as dsm_xwalk[[type]].
# Goal: a patient stem that already is a DSM stem is kept as-is; a patient stem
# that is a close_stem or synonym_stem of some DSM stem is later replaced with
# that DSM stem. A search stem may map to several DSM stems; those candidates
# are spread into columns stem1, stem2, ... (close-stem pairs first, then
# synonym pairs) and the replacement step picks one of them per patient.
dsm_xwalk <- list()
dsm_types <- unique(DSM_Vector$type)

for (dsm_type in dsm_types) {
  type_rows <- DSM_Vector %>% filter(type == .env$dsm_type)
  base_stems <- unique(type_rows$stem)

  # (stem_base, stem_search) pairs from each source; a search stem that is
  # itself a DSM stem of this type needs no mapping and is dropped
  close_pairs <- type_rows %>%
    select(stem, close_stem) %>%
    distinct() %>%
    rename(stem_base = stem, stem_search = close_stem) %>%
    filter(!(stem_search %in% base_stems))
  syn_pairs <- type_rows %>%
    select(stem, synonym_stem) %>%
    distinct() %>%
    rename(stem_base = stem, stem_search = synonym_stem) %>%
    filter(!(stem_search %in% base_stems))

  # one row per search stem, candidate DSM stems in columns stem1, stem2, ...
  dsm_xwalk[[dsm_type]] <- rbind(close_pairs, syn_pairs) %>%
    group_by(stem_search) %>%
    mutate(stem_num = paste0("stem", row_number())) %>%
    ungroup() %>%
    pivot_wider(names_from = "stem_num", values_from = "stem_base")

  rm(type_rows, base_stems, close_pairs, syn_pairs)
}
#note that type1 has 3 stem options, type2 has 2 stem options


#################################################
#Step 2: read in the assigned ML prediction (note_dat_fake_wQ.csv), 
#and define inclusion window variable
#################################################
ml_notes <- read.csv(file.path("..", "data", "intermediate", "note_dat_fake_wQ.csv"),
                     stringsAsFactors = FALSE)

# Diagnosis timing per patient is computed over ALL visits, before restricting
# to behavioral visits. Otherwise a first ADHD diagnosis recorded at a visit
# the classifier did not flag as behavioral would be ignored: the patient
# would look never-diagnosed (or diagnosed later), and post-diagnosis notes
# would leak into the _01/_30/_60 windows. This also keeps adhd_ever in line
# with dx_status in Step 7, which is computed from all visits.
#   adhd_ever:     max(adhd_dx) over the patient's visits
#   dx_age:        min(age) among visits with adhd_dx == 1; the max() picks it
#                  over the 0s of non-dx visits (assumes ages are >= 0)
#   days_since_dx: (age - dx_age) * 365 for diagnosed patients, 0 otherwise
nlp_dat <- ml_notes %>%
  group_by(pat_id) %>%
  mutate(adhd_ever = max(adhd_dx)) %>%
  group_by(pat_id, adhd_dx) %>%
  mutate(dx_age = min(age) * adhd_dx) %>%
  group_by(pat_id) %>%
  mutate(dx_age = max(dx_age)) %>%
  ungroup() %>%
  mutate(days_since_dx = adhd_ever * (age - dx_age) * 365) %>%
  # only behavioral visits are used to compute the ADHD match:
  # assigned_label == 1 or prob1 > 0.5 (rows where this is NA are dropped)
  filter(assigned_label == 1 | prob1 > 0.5)

#next define includeNLP label 
  #_all: all behavioral appt 
  #_01: all behavioral appt up to and including the first adhd dx  
  #_30: all behavioral appt up to and including those 30 days after the first adhd dx  
  #_60: all behavioral appt up to and including those 60 days after the first adhd dx  
# A diagnosed visit whose days_since_dx is NA falls through to 0 (excluded).
nlp_dat <- nlp_dat %>%
  mutate(includeNLP_all = 1,
         includeNLP_01 = case_when(adhd_ever == 0 ~ 1,
                                   adhd_ever == 1 & days_since_dx <= 0 ~ 1,
                                   TRUE ~ 0),
         includeNLP_30 = case_when(adhd_ever == 0 ~ 1,
                                   adhd_ever == 1 & days_since_dx <= 30 ~ 1,
                                   TRUE ~ 0),
         includeNLP_60 = case_when(adhd_ever == 0 ~ 1,
                                   adhd_ever == 1 & days_since_dx <= 60 ~ 1,
                                   TRUE ~ 0))

# only keep pat_id and visit_id with at least one includeNLP_*=1
# (includeNLP_all is always 1, so this should keep every row; it is a guard)
nlp_dat <- nlp_dat %>%
  mutate(ever_include = includeNLP_all + includeNLP_01 + includeNLP_30 + includeNLP_60) %>%
  filter(ever_include > 0) %>%
  select(pat_id, visit_id, starts_with("includeNLP"))



#################################################
#Step 3: read in the clean patient notes vector 
#################################################

# Cleaned note tokens; downstream code joins on pat_id/visit_id and reads stem
# (presumably one row per stem occurrence -- confirm against the producer).
notes_dta_cleaned <- file.path("..", "data", "intermediate", "pat_notes_fake_Vectors.csv") %>%
  fread(stringsAsFactors = FALSE)

#################################################
#Step 4: write a function that returns the tokenized data (after stem replacement) for a given inclusion window and adhd type 
#################################################

# GetTokenDat: unigram + bigram token tables for one inclusion window and one
# DSM type.
#
# Args:
#   includeNLP_def: one of "includeNLP_all", "includeNLP_01", "includeNLP_30",
#     "includeNLP_60"; keeps the nlp_dat visits flagged 1 in that column.
#   dsm_type_num: a value of dsm_types; selects DSM_Vector rows with that
#     type and the crosswalk dsm_xwalk[[dsm_type_num]].
# Returns a list with
#   notes_tokens: (document, word) unigrams + bigrams, one document per patient
#   dsm_tokens:   (document, word) unigrams + bigrams of the DSM stems, stored
#                 as document number (number of patient documents + 1)
#   pat_id_xwalk: (pat_id, document) map from document number back to patient
# Errors (stop) on an invalid argument or when the window has no note tokens.
# Reads the globals nlp_dat, notes_dta_cleaned, dsm_xwalk, dsm_types and
# DSM_Vector.
GetTokenDat <- function(includeNLP_def, dsm_type_num) {
  valid_defs <- c("includeNLP_all", "includeNLP_01", "includeNLP_30", "includeNLP_60")
  if (!(includeNLP_def %in% valid_defs)) {
    stop("includeNLP_def not valid", call. = FALSE)
  }
  if (!(dsm_type_num %in% dsm_types)) {
    stop("dsm_type_num not valid", call. = FALSE)
  }

  # Unigrams plus bigrams of each document's concatenated word stream.
  # Documents too short to form a bigram can yield NA bigrams; drop those.
  add_bigrams <- function(uni) {
    bi <- uni %>%
      group_by(document) %>%
      summarize(combined = str_flatten(word, " ")) %>%
      unnest_tokens(word, combined, token = "ngrams", n = 2) %>%
      ungroup() %>%
      filter(!is.na(word))
    rbind(uni, bi)
  }

  # Note stems for visits in the window. Visits without rows in
  # notes_dta_cleaned contribute no tokens: with a left join they would carry
  # an NA stem, and str_flatten() of a stream containing NA yields NA, which
  # silently drops every bigram of that patient.
  Notes_Vector <- nlp_dat %>%
    filter(.data[[includeNLP_def]] == 1) %>%
    inner_join(notes_dta_cleaned, by = c("pat_id", "visit_id")) %>%
    filter(!is.na(stem)) %>%
    select(pat_id, visit_id, stem)
  if (nrow(Notes_Vector) == 0) {
    stop("no note tokens for ", includeNLP_def, call. = FALSE)
  }

  # Long crosswalk: one row per (search stem, candidate DSM stem), where
  # stem_rank is the N of the stemN column the candidate came from. Works for
  # any number of stemN columns (type 1 has three, type 2 has two).
  xwalk_wide <- dsm_xwalk[[dsm_type_num]]
  stem_cols <- grep("^stem[0-9]+$", names(xwalk_wide), value = TRUE)
  if (length(stem_cols) == 0) {
    xwalk_long <- tibble(stem_search = character(), stem_rank = integer(),
                         stem_base = character())
  } else {
    xwalk_long <- xwalk_wide %>%
      pivot_longer(all_of(stem_cols), names_to = "stem_rank",
                   values_to = "stem_base", values_drop_na = TRUE) %>%
      mutate(stem_rank = as.integer(sub("^stem", "", stem_rank))) %>%
      select(stem_search, stem_rank, stem_base)
  }

  # How often each stem occurs, as a whole token, in each patient's
  # pre-replacement stems within the window. (Substring/regex counting would
  # also count e.g. "act" inside "hyperact".)
  stem_counts <- Notes_Vector %>%
    count(pat_id, stem, name = "n_in_note") %>%
    rename(stem_base = stem)

  # For each (patient, note stem) with crosswalk candidates pick the candidate
  # that occurs least often in the patient's notes (to ensure variation); ties
  # go to the lowest stem_rank. A single candidate is simply used.
  stem_choice <- Notes_Vector %>%
    distinct(pat_id, stem) %>%
    inner_join(xwalk_long, by = c("stem" = "stem_search")) %>%
    left_join(stem_counts, by = c("pat_id", "stem_base")) %>%
    mutate(n_in_note = coalesce(n_in_note, 0L)) %>%
    arrange(pat_id, stem, n_in_note, stem_rank) %>%
    distinct(pat_id, stem, .keep_all = TRUE) %>%
    select(pat_id, stem, stem_base)

  # Apply the replacement; stems without a crosswalk entry are kept.
  # left_join keeps row order, which matters for the bigrams below.
  Notes_Vector <- Notes_Vector %>%
    left_join(stem_choice, by = c("pat_id", "stem")) %>%
    mutate(stem = coalesce(stem_base, stem)) %>%
    select(pat_id, visit_id, stem) %>%
    group_by(pat_id) %>%
    mutate(pat_num = cur_group_id()) %>%
    ungroup()

  # patient <-> document crosswalk returned to the caller
  pat_id_xwalk <- Notes_Vector %>%
    distinct(pat_id, document = pat_num)

  notes_uni <- Notes_Vector %>%
    transmute(document = pat_num, word = stem)

  # DSM stems of this type form one extra document after the patient documents
  num_pat_docs <- max(Notes_Vector$pat_num)
  dsm_uni <- DSM_Vector %>%
    filter(type == dsm_type_num) %>%
    mutate(document = num_pat_docs + 1, word = stem) %>%
    select(document, word)

  list(notes_tokens = add_bigrams(notes_uni),
       dsm_tokens = add_bigrams(dsm_uni),
       pat_id_xwalk = pat_id_xwalk)
}




#################################################
#Step 4b: write a function that returns the cosine similarity for each patient given tokens
#################################################

# GetCosineMatch: cosine similarity between each patient document and the DSM
# document.
#
# Args:
#   token_output: list as returned by GetTokenDat(), with elements
#     notes_tokens (document, word), dsm_tokens (document, word; a single
#     document) and pat_id_xwalk (pat_id, document).
# Returns a list with
#   cosine_match_dat:  pat_id_xwalk with a cosine_match column
#   cosine_self_check: one-row data.frame (cosine_match, document) for the DSM
#                      document against itself; should be 1 up to rounding
# Weights: each patient word gets tf * idf with tf = count / document total
# and idf = log(N / docs containing word) + 1 over the N patient documents;
# every DSM word gets weight 1.
GetCosineMatch <- function(token_output) {
  notes_tokens <- token_output$notes_tokens
  dsm_tokens <- token_output$dsm_tokens
  pat_id_xwalk <- token_output$pat_id_xwalk

  dsm_doc <- unique(dsm_tokens$document)
  if (length(dsm_doc) != 1) {
    stop("dsm_tokens must hold exactly one document", call. = FALSE)
  }

  # tf-idf for the patient notes only
  n_pat_docs <- n_distinct(notes_tokens$document)
  notes_tfidf <- notes_tokens %>%
    count(document, word, name = "word_count") %>%
    group_by(document) %>%
    mutate(tf = word_count / sum(word_count)) %>%
    group_by(word) %>%
    mutate(idf = log(n_pat_docs / n_distinct(document)) + 1) %>%
    ungroup() %>%
    select(document, word, tf, idf)

  # DSM words get unit weight
  dsm_weights <- dsm_tokens %>%
    mutate(tf = 1, idf = 1)

  # unique (document, word, tf_idf) rows for all documents
  doc_tokens <- dsm_weights %>%
    rbind(notes_tfidf) %>%
    arrange(document) %>%
    mutate(tf_idf = tf * idf) %>%
    select(document, word, tf_idf) %>%
    distinct() %>%
    arrange(document)

  # cast_dtm(data, document, term, value) is called with word in the first
  # slot and document in the second, so tdm is a word x document matrix.
  tdm <- as.matrix(cast_dtm(doc_tokens, word, document, tf_idf))
  doc_ids <- as.numeric(colnames(tdm))

  # Locate the DSM column by its document id instead of assuming it is the
  # last column of the matrix.
  dsm_col <- which(doc_ids == dsm_doc)
  if (length(dsm_col) != 1) {
    stop("DSM document not found in the term-document matrix", call. = FALSE)
  }
  dsm_vec <- tdm[, dsm_col]

  # cosine of every document column with the DSM column
  dot_with_dsm <- as.numeric(crossprod(tdm, dsm_vec))
  col_norms <- unname(sqrt(colSums(tdm^2)))
  cosine_all <- data.frame(
    cosine_match = dot_with_dsm / (col_norms * sqrt(sum(dsm_vec^2))),
    document = doc_ids
  )

  list(cosine_match_dat = left_join(pat_id_xwalk, cosine_all, by = "document"),
       cosine_self_check = cosine_all[dsm_col, , drop = FALSE])
}

#################################################
#Step 5: Prep dataset and fill in with match values accordingly 
  #different dataset for each includeNLP option 
  #matches values for both types 
#################################################

label_types <- c("includeNLP_all", "includeNLP_01", "includeNLP_30", "includeNLP_60")

# For every (DSM type, inclusion window) pair: build the tokens, compute the
# cosine matches, verify the DSM document matches itself, and tag the rows
# with lab and dsm_type. Results are collected in a preallocated list and
# bound once (no rbind growth, no placeholder NA row).
match_results <- vector("list", length(dsm_types) * length(label_types))
result_idx <- 0L
for (d in dsm_types) {
  for (lab in label_types) {
    token_input <- GetTokenDat(lab, d)
    cosine_out <- GetCosineMatch(token_input)

    # the DSM document against itself must give cosine 1 (up to rounding);
    # a missing or NA value is treated as a failure too
    self_match <- cosine_out$cosine_self_check[["cosine_match"]]
    if (!isTRUE(abs(self_match - 1) <= 1e-8)) {
      stop("Cosine match with DSM vector is not 1. Error in code above")
    }

    result_idx <- result_idx + 1L
    match_results[[result_idx]] <- cosine_out$cosine_match_dat %>%
      mutate(lab = lab, dsm_type = d)
  }
}

# one row per (patient, lab, dsm_type); rows without a pat_id are dropped
outer_dat <- bind_rows(match_results) %>%
  filter(!is.na(pat_id))
rm(match_results, result_idx)

#################################################
#Step 6: clean and output a final match dataset
#################################################
#varnames: pat_id, Qi, xi, type2_max, xi_01, xi_30, xi_60, xi_all, note_length


# Qi: 1 if any of the patient's visits is behavioral (assigned_label == 1 or
# prob1 > 0.5). A missing label or probability counts as "not behavioral",
# the same way the visit filter in Step 2 drops such rows, so Qi is never NA
# just because prob1 is missing.
pat_qi <- ml_notes %>%
  mutate(qij = as.integer(coalesce(assigned_label == 1, FALSE) |
                            coalesce(prob1 > 0.5, FALSE))) %>%
  group_by(pat_id) %>%
  summarise(Qi = max(qij)) %>%
  ungroup()


# Within each lab-type, rescale the cosine match to [0, 1] (subtract min and
# divide by range). Missing cosines do not wipe out the whole group, and a
# group whose values are all equal (range 0) gets NA instead of NaN.
match_dat <- outer_dat %>%
  group_by(lab, dsm_type) %>%
  mutate(match_min = min(cosine_match, na.rm = TRUE),
         match_range = max(cosine_match, na.rm = TRUE) - match_min) %>%
  ungroup() %>%
  mutate(cosine_match = if_else(match_range > 0,
                                (cosine_match - match_min) / match_range,
                                NA_real_)) %>%
  select(-c(match_min, match_range))

# next get xi_* which is max of match for type1, type2, given inclusion window *
# (type2_max = 1 when the type-2 match attains that max)
match_dat <- match_dat %>%
  mutate(dsm_type = as.character(dsm_type)) %>%
  select(pat_id, lab, dsm_type, cosine_match) %>%
  pivot_wider(
    id_cols      = c(pat_id, lab),
    names_from   = dsm_type,
    values_from  = cosine_match,
    names_prefix = "cosine_match_"
  ) %>%
  mutate(xi = pmax(cosine_match_1, cosine_match_2),
         type2_max = as.integer(xi == cosine_match_2)) %>%
  select(-c(cosine_match_1, cosine_match_2))

# reshape wide again so one row per patient with inclusion window suffix
match_dat <- match_dat %>%
  mutate(lab = sub("^includeNLP_", "", lab)) %>%  # all, 01, 30, 60
  pivot_wider(
    id_cols = pat_id,
    names_from = lab,
    values_from = c(xi, type2_max),
    names_glue = "{.value}_{lab}"
  )

# for baseline specification, the inclusion window is 60.
#   define xi=xi_60 and type2_max=type2_max_60
#   drop the other type_2s
match_dat <- match_dat %>%
  mutate(xi = xi_60,
         type2_max = type2_max_60) %>%
  select(pat_id, xi, type2_max, xi_all, xi_01, xi_30, xi_60)


# note length (number of stems) for the baseline window. Visits without rows
# in notes_dta_cleaned leave an NA stem after the left join; those contribute
# 0 rather than 1 (which is what counting rows with n() would give).
note_length_dat <- nlp_dat %>%
  filter(includeNLP_60 == 1) %>%
  left_join(notes_dta_cleaned, by = c("pat_id", "visit_id")) %>%
  group_by(pat_id) %>%
  summarise(note_length = sum(!is.na(stem)))

# put together match_dat, note_length_dat, and pat_qi
nlp_return_dat <- pat_qi %>%
  left_join(note_length_dat, by = "pat_id") %>%
  left_join(match_dat, by = "pat_id")

# save
write.csv(nlp_return_dat,
          file.path("..", "data", "intermediate", "nlp_match_dat_fake.csv"),
          row.names = FALSE)

#################################################
#Step 7: get the most predictive words for those with adhd dx and those without (for table c1)
#################################################


# diagnosis status (ever diagnosed) and male indicator, one row per patient
dx_status <- ml_notes %>%
  group_by(pat_id) %>%
  summarise(adhd_ever = max(adhd_dx),
            male = max(male)) %>%
  ungroup()

# high match status: xi at or above its top-tercile cutoff
# (argument spelled out as `probs`; `p =` only worked via partial matching)
top_tercile_x <- quantile(nlp_return_dat$xi, probs = 2/3, na.rm = TRUE)
hi_xi_status <- nlp_return_dat %>%
  mutate(hi_xi = as.integer(xi >= top_tercile_x)) %>%
  select(pat_id, hi_xi)


# Unique stems per patient in the baseline (60-day) window, with diagnosis and
# high-match status, stop words removed. Visits without note rows are
# skipped, so a missing stem can never be reported as a "top word".
predictive_words_dat <- nlp_dat %>%
  filter(includeNLP_60 == 1) %>%
  inner_join(notes_dta_cleaned, by = c("pat_id", "visit_id")) %>%
  filter(!is.na(stem)) %>%
  distinct(pat_id, stem) %>%
  left_join(dx_status, by = "pat_id") %>%
  left_join(hi_xi_status, by = "pat_id") %>%
  anti_join(stop_words, by = c("stem" = "word"))

# get_top_words: stems whose share of rows in `dataset` is at or above the
# `thresh` quantile of all stem shares, most frequent first.
#   dataset: data frame with pat_id and stem columns; each row counts once
#   thresh:  quantile level, e.g. 0.95 keeps roughly the top 5% of stems
# Returns a character vector of stems. Stems with equal frequency keep the
# order in which they first appear in `dataset`.
get_top_words <- function(dataset, thresh) {
  stem_share <- dataset %>%
    ungroup() %>%
    select(pat_id, stem) %>%
    add_count(stem, name = "top_words") %>%
    distinct(stem, top_words) %>%
    arrange(desc(top_words)) %>%
    mutate(tf = top_words / sum(top_words))

  cutoff <- quantile(stem_share$tf, thresh)

  stem_share %>%
    filter(tf >= cutoff) %>%
    arrange(desc(tf)) %>%
    pull(stem)
}

# print top words overall, male, female that are not in the top 5% of the converse
top_pct <- .95   # quantile cutoff passed to get_top_words: keeps roughly the top 5% of words
n_word_find <- 5 # number of words to print in table

# Top words of the patients selected by `in_group`, after removing words that
# are also top words of the patients selected by `out_group`. Both arguments
# are filter expressions evaluated inside predictive_words_dat (rows where
# they are NA are dropped, as with filter()). Returns up to n_word_find stems.
find_distinctive_words <- function(in_group, out_group) {
  out_top <- predictive_words_dat %>%
    filter({{ out_group }}) %>%
    get_top_words(top_pct)
  predictive_words_dat %>%
    filter({{ in_group }}) %>%
    filter(!(stem %in% out_top)) %>%
    get_top_words(top_pct) %>%
    head(n_word_find)
}

#### for those with adhd dx (overall, male, female)
adhd_words_overall <- find_distinctive_words(adhd_ever == 1, adhd_ever == 0)
adhd_words_male <- find_distinctive_words(adhd_ever == 1 & male == 1,
                                          adhd_ever == 0 & male == 1)
adhd_words_female <- find_distinctive_words(adhd_ever == 1 & male == 0,
                                            adhd_ever == 0 & male == 0)

#### for those with high xi (overall, male, female)
hixi_words_overall <- find_distinctive_words(hi_xi == 1, hi_xi == 0)
hixi_words_male <- find_distinctive_words(hi_xi == 1 & male == 1,
                                          hi_xi == 0 & male == 1)
hixi_words_female <- find_distinctive_words(hi_xi == 1 & male == 0,
                                            hi_xi == 0 & male == 0)

# Escape characters that are special in LaTeX text mode (& % $ # _ { }) so an
# unusual stem cannot break the table.
escape_latex <- function(x) gsub("([&%$#_{}])", "\\\\\\1", x)

# One indented table row: "<label> & <comma-separated italic words>"
table_row <- function(label, words) {
  sprintf("\\hspace{3mm} %s & \\textit{%s} \\\\ \n",
          label, paste(escape_latex(words), collapse = ", "))
}

### print the table c1
top_words_tab <- paste(
  "\\begin{tabular}{ll} \n",
  "\\toprule \n",
  "\\multicolumn{2}{l}{\\textbf{Patients with ADHD Diagnosis} } \\\\ \n",
  table_row("Overall", adhd_words_overall),
  table_row("Male", adhd_words_male),
  table_row("Female", adhd_words_female),
  "\\hline \n",
  "\\addlinespace \n",
  "\\multicolumn{2}{l}{\\textbf{Patients with High ADHD Match} } \\\\ \n",
  table_row("Overall", hixi_words_overall),
  table_row("Male", hixi_words_male),
  table_row("Female", hixi_words_female),
  "\\hline \n",
  "\\bottomrule \n",
  "\\end{tabular} \n")

# save
write(top_words_tab, file.path("..", "output", "tables", "tab_c1.txt"))

#END OF SCRIPT
